{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 垃圾邮件分类"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "SPAM_PATH = os.path.join('datasets', 'spam')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 获取邮件名\n",
    "HAM_DIR = os.path.join(SPAM_PATH, 'easy_ham')\n",
    "SPAM_DIR = os.path.join(SPAM_PATH, 'spam')\n",
    "ham_filenames = [name for name in sorted(os.listdir(HAM_DIR)) if len(name) > 20]\n",
    "spam_filenames = [name for name in sorted(os.listdir(SPAM_DIR)) if len(name) > 20]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 使用python的email模块解析这些电子邮件（它处理邮件头、编码等）\n",
    "import email\n",
    "import email.policy\n",
    "\n",
    "def load_email(is_spam, filename, spam_path = SPAM_PATH):\n",
    "    directory = 'spam' if is_spam else 'easy_ham'\n",
    "    with open(os.path.join(spam_path, directory, filename), 'rb') as f:\n",
    "        return email.parser.BytesParser(policy = email.policy.default).parse(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Martin A posted:\n",
      "Tassos Papadopoulos, the Greek sculptor behind the plan, judged that the\n",
      " limestone of Mount Kerdylio, 70 miles east of Salonika and not far from the\n",
      " Mount Athos monastic community, was ideal for the patriotic sculpture. \n",
      " \n",
      " As well as Alexander's granite features, 240 ft high and 170 ft wide, a\n",
      " museum, a restored amphitheatre and car park for admiring crowds are\n",
      "planned\n",
      "---------------------\n",
      "So is this mountain limestone or granite?\n",
      "If it's limestone, it'll weather pretty fast.\n",
      "\n",
      "------------------------ Yahoo! Groups Sponsor ---------------------~-->\n",
      "4 DVDs Free +s&p Join Now\n",
      "http://us.click.yahoo.com/pt6YBB/NXiEAA/mG3HAA/7gSolB/TM\n",
      "---------------------------------------------------------------------~->\n",
      "\n",
      "To unsubscribe from this group, send an email to:\n",
      "forteana-unsubscribe@egroups.com\n",
      "\n",
      " \n",
      "\n",
      "Your use of Yahoo! Groups is subject to http://docs.yahoo.com/info/terms/\n"
     ]
    }
   ],
   "source": [
    "# 看一个ham示例和一个spam示例，了解数据的外观：\n",
    "ham_emails = [load_email(is_spam = False, filename = name) for name in ham_filenames]\n",
    "spam_emails = [load_email(is_spam = True, filename = name) for name in spam_filenames]\n",
    "print(ham_emails[1].get_content().strip())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Help wanted.  We are a 14 year old fortune 500 company, that is\n",
      "growing at a tremendous rate.  We are looking for individuals who\n",
      "want to work from home.\n",
      "\n",
      "This is an opportunity to make an excellent income.  No experience\n",
      "is required.  We will train you.\n",
      "\n",
      "So if you are looking to be employed from home with a career that has\n",
      "vast opportunities, then go:\n",
      "\n",
      "http://www.basetel.com/wealthnow\n",
      "\n",
      "We are looking for energetic and self motivated people.  If that is you\n",
      "than click on the link and fill out the form, and one of our\n",
      "employement specialist will contact you.\n",
      "\n",
      "To be removed from our link simple go to:\n",
      "\n",
      "http://www.basetel.com/remove.html\n",
      "\n",
      "\n",
      "4139vOLW7-758DoDY1425FRhM1-764SMFc8513fCsLl40\n"
     ]
    }
   ],
   "source": [
    "print(spam_emails[6].get_content().strip())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 电子邮件有很多部分，带有图像和附件（它们可以有自己的附件）。查看邮件的各种类型的结构：\n",
    "def get_email_structure(email):\n",
    "    if isinstance(email, str):\n",
    "        return email\n",
    "    payload = email.get_payload()\n",
    "    if isinstance(payload, list):\n",
    "        return 'multipart({})'.format(', '.join([\n",
    "            get_email_structure(sub_email)\n",
    "            for sub_email in payload\n",
    "        ]))\n",
    "    else:\n",
    "        return email.get_content_type()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Counter({1: 1, 4: 2, 2: 3, 3: 2})"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from collections import Counter\n",
    "a = [1,4,2,3,2,3,4,2]  \n",
    " \n",
    "b = Counter(a) #求数组中每个数字出现了几次\n",
    "b"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import Counter\n",
    "\n",
    "def structures_counter(emails):\n",
    "    structures = Counter()\n",
    "    for email in emails:\n",
    "        structure = get_email_structure(email)\n",
    "        structures[structure] += 1\n",
    "    return structures"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('text/plain', 2408),\n",
       " ('multipart(text/plain, application/pgp-signature)', 66),\n",
       " ('multipart(text/plain, text/html)', 8),\n",
       " ('multipart(text/plain, text/plain)', 4),\n",
       " ('multipart(text/plain)', 3),\n",
       " ('multipart(text/plain, application/octet-stream)', 2),\n",
       " ('multipart(text/plain, text/enriched)', 1),\n",
       " ('multipart(text/plain, application/ms-tnef, text/plain)', 1),\n",
       " ('multipart(multipart(text/plain, text/plain, text/plain), application/pgp-signature)',\n",
       "  1),\n",
       " ('multipart(text/plain, video/mng)', 1),\n",
       " ('multipart(text/plain, multipart(text/plain))', 1),\n",
       " ('multipart(text/plain, application/x-pkcs7-signature)', 1),\n",
       " ('multipart(text/plain, multipart(text/plain, text/plain), text/rfc822-headers)',\n",
       "  1),\n",
       " ('multipart(text/plain, multipart(text/plain, text/plain), multipart(multipart(text/plain, application/x-pkcs7-signature)))',\n",
       "  1),\n",
       " ('multipart(text/plain, application/x-java-applet)', 1)]"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "structures_counter(ham_emails).most_common()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('text/plain', 218),\n",
       " ('text/html', 183),\n",
       " ('multipart(text/plain, text/html)', 45),\n",
       " ('multipart(text/html)', 20),\n",
       " ('multipart(text/plain)', 19),\n",
       " ('multipart(multipart(text/html))', 5),\n",
       " ('multipart(text/plain, image/jpeg)', 3),\n",
       " ('multipart(text/html, application/octet-stream)', 2),\n",
       " ('multipart(text/plain, application/octet-stream)', 1),\n",
       " ('multipart(text/html, text/plain)', 1),\n",
       " ('multipart(multipart(text/html), application/octet-stream, image/jpeg)', 1),\n",
       " ('multipart(multipart(text/plain, text/html), image/gif)', 1),\n",
       " ('multipart/alternative', 1)]"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "structures_counter(spam_emails).most_common()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Return-Path : <12a1mailbot1@web.de>\n",
      "Delivered-To : zzzz@localhost.spamassassin.taint.org\n",
      "Received : from localhost (localhost [127.0.0.1])\tby phobos.labs.spamassassin.taint.org (Postfix) with ESMTP id 136B943C32\tfor <zzzz@localhost>; Thu, 22 Aug 2002 08:17:21 -0400 (EDT)\n",
      "Received : from mail.webnote.net [193.120.211.219]\tby localhost with POP3 (fetchmail-5.9.0)\tfor zzzz@localhost (single-drop); Thu, 22 Aug 2002 13:17:21 +0100 (IST)\n",
      "Received : from dd_it7 ([210.97.77.167])\tby webnote.net (8.9.3/8.9.3) with ESMTP id NAA04623\tfor <zzzz@spamassassin.taint.org>; Thu, 22 Aug 2002 13:09:41 +0100\n",
      "From : 12a1mailbot1@web.de\n",
      "Received : from r-smtp.korea.com - 203.122.2.197 by dd_it7  with Microsoft SMTPSVC(5.5.1775.675.6);\t Sat, 24 Aug 2002 09:42:10 +0900\n",
      "To : dcek1a1@netsgo.com\n",
      "Subject : Life Insurance - Why Pay More?\n",
      "Date : Wed, 21 Aug 2002 20:31:57 -1600\n",
      "MIME-Version : 1.0\n",
      "Message-ID : <0103c1042001882DD_IT7@dd_it7>\n",
      "Content-Type : text/html; charset=\"iso-8859-1\"\n",
      "Content-Transfer-Encoding : quoted-printable\n"
     ]
    }
   ],
   "source": [
    "# 正常邮件更多的是纯文本，而垃圾邮件有相当多的HTML。\n",
    "# 查看邮件头\n",
    "for header, value in spam_emails[0].items():\n",
    "    print(header,':', value)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'Life Insurance - Why Pay More?'"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 里面可能有很多有用的信息，比如发件人的电子邮件地址（12a1mailbot1@web.de看起来很可疑），\n",
    "# 查看“主题”标题：\n",
    "\n",
    "spam_emails[0]['Subject']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 拆分训练集和测试集合"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "c:\\users\\stefa\\appdata\\local\\programs\\python\\python37\\lib\\site-packages\\ipykernel_launcher.py:4: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray\n",
      "  after removing the cwd from sys.path.\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "X = np.array(ham_emails + spam_emails)\n",
    "y = np.array([0] * len(ham_emails) + [1] * len(spam_emails))\n",
    "\n",
    "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.2, random_state = 42)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "   * 首先需要一个函数来将html转换为纯文本，使用Beautifulsoup库，\n",
    "   * 下面的函数首先删除<head>部分，然后将所有<a>标记转换为单词hyperlink，然后去掉所有html标记，只留下纯文本。\n",
    "    为了可读性，它还用一个换行符替换多个换行符，最后它取消了HTML实体（例如`&gt；`或`&nbsp；`）\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "from html import unescape\n",
    "\n",
    "def html_to_plain_text(html):\n",
    "    text = re.sub('<head.*?>.*?</head>', '', html, flags = re.M | re.S | re.I)\n",
    "    text = re.sub('<a\\s.*?>', ' HYPERLINK ', text, flags = re.M | re.S | re.I)\n",
    "    text = re.sub('<.*?>', '', text, flags = re.M | re.S)\n",
    "    text = re.sub(r'(\\s*\\n) + ', '\\n', text, flags = re.M | re.S)\n",
    "    return unescape(text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<HTML><HEAD><TITLE></TITLE><META http-equiv=\"Content-Type\" content=\"text/html; charset=windows-1252\"><STYLE>A:link {TEX-DECORATION: none}A:active {TEXT-DECORATION: none}A:visited {TEXT-DECORATION: none}A:hover {COLOR: #0033ff; TEXT-DECORATION: underline}</STYLE><META content=\"MSHTML 6.00.2713.1100\" name=\"GENERATOR\"></HEAD>\n",
      "<BODY text=\"#000000\" vLink=\"#0033ff\" link=\"#0033ff\" bgColor=\"#CCCC99\"><TABLE borderColor=\"#660000\" cellSpacing=\"0\" cellPadding=\"0\" border=\"0\" width=\"100%\"><TR><TD bgColor=\"#CCCC99\" valign=\"top\" colspan=\"2\" height=\"27\">\n",
      "<font size=\"6\" face=\"Arial, Helvetica, sans-serif\" color=\"#660000\">\n",
      "<b>OTC</b></font></TD></TR><TR><TD height=\"2\" bgcolor=\"#6a694f\">\n",
      "<font size=\"5\" face=\"Times New Roman, Times, serif\" color=\"#FFFFFF\">\n",
      "<b>&nbsp;Newsletter</b></font></TD><TD height=\"2\" bgcolor=\"#6a694f\"><div align=\"right\"><font color=\"#FFFFFF\">\n",
      "<b>Discover Tomorrow's Winners&nbsp;</b></font></div></TD></TR><TR><TD height=\"25\" colspan=\"2\" bgcolor=\"#CCCC99\"><table width=\"100%\" border=\"0\"  ...\n"
     ]
    }
   ],
   "source": [
    "html_spam_emails = [email for email in X_train[y_train == 1]\n",
    "                    if get_email_structure(email) == 'text/html']\n",
    "sample_html_spam = html_spam_emails[7]\n",
    "print(sample_html_spam.get_content().strip()[:1000], '...')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "\n",
      "OTC\n",
      "\n",
      " Newsletter\n",
      "Discover Tomorrow's Winners \n",
      "\n",
      "For Immediate Release\n",
      "\n",
      "Cal-Bay (Stock Symbol: CBYI)\n",
      "Watch for analyst \"Strong Buy Recommendations\" and several advisory newsletters picking CBYI.  CBYI has filed to be traded on the OTCBB, share prices historically INCREASE when companies get listed on this larger trading exchange. CBYI is trading around 25 cents and should skyrocket to $2.66 - $3.25 a share in the near future.\n",
      "Put CBYI on your watch list, acquire a position TODAY.\n",
      "\n",
      "REASONS TO INVEST IN CBYI\n",
      "\n",
      "A profitable company and is on track to beat ALL earnings estimates!\n",
      "\n",
      "One of the FASTEST growing distributors in environmental & safety equipment instruments.\n",
      "\n",
      "Excellent management team, several EXCLUSIVE contracts.  IMPRESSIVE client list including the U.S. Air Force, Anheuser-Busch, Chevron Refining and Mitsubishi Heavy Industries, GE-Energy & Environmental Research.\n",
      "\n",
      "RAPIDLY GROWING INDUSTRY\n",
      "Industry revenues exceed $900 million, estimates indicate that there could be as much as ...\n"
     ]
    }
   ],
   "source": [
    "print(html_to_plain_text(sample_html_spam.get_content())[:1000], '...')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 编写一个函数，它以电子邮件为输入，并以纯文本形式返回其内容，无论其格式是什么\n",
    "def email_to_text(email):\n",
    "    html = None\n",
    "    for part in email.walk():\n",
    "        ctype = part.get_content_type()\n",
    "        if not ctype in ('text/plain', 'text/html'):\n",
    "            continue\n",
    "        try:\n",
    "            content = part.get_content()\n",
    "        except: # 解决编码问题\n",
    "            content = str(part.get_payload())\n",
    "        if ctype == 'text/plain':\n",
    "            return content\n",
    "        else:\n",
    "            html = content\n",
    "    if html:\n",
    "        return html_to_plain_text(html)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "\n",
      "OTC\n",
      "\n",
      " Newsletter\n",
      "Discover Tomorrow's Winners \n",
      "\n",
      "For Immediate Release\n",
      "\n",
      "Cal-Bay (Stock Symbol: CBYI ...\n"
     ]
    }
   ],
   "source": [
    "print(email_to_text(sample_html_spam)[:100], '...')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Computations => comput\n",
      "Computation => comput\n",
      "Computing => comput\n",
      "Computed => comput\n",
      "Compute => comput\n",
      "Compulsive => compuls\n"
     ]
    }
   ],
   "source": [
    "# cmd下安装自然语言工具包（[nltk]（http://www.nltk.org/）\n",
    "# pip install nltk\n",
    "\n",
    "# 用“url”替换url的方法 \n",
    "# pip install urlextract\n",
    "import nltk\n",
    "from urlextract import URLExtract\n",
    "\n",
    "try:\n",
    "    import nltk\n",
    "\n",
    "    stemmer = nltk.PorterStemmer()\n",
    "    for word in ('Computations', 'Computation', 'Computing', 'Computed', 'Compute', 'Compulsive'):\n",
    "        print(word, '=>', stemmer.stem(word))\n",
    "except ImportError:\n",
    "    print('Error: stemming requires the NLTK module.')\n",
    "    stemmer = None"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "   * 将所有处理整合到一个转换器中，我们将使用它将电子邮件转换为文字计数器。\n",
    "   * 注意，我们使用python的split（）方法将句子拆分为单词，该方法使用空格作为单词边界。\n",
    "   * 但例如，汉语和日语脚本通常不在单词之间使用空格,在这个练习中没关系，因为数据集（主要）是英文的，中文可以使用结巴分词来进行拆分¶"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.base import BaseEstimator, TransformerMixin\n",
    "\n",
    "class EmailToWordCounterTransformer(BaseEstimator, TransformerMixin):\n",
    "    def __init__(self, strip_headers = True, lower_case = True, remove_punctuation = True,\n",
    "                 replace_urls = True, replace_numbers = True, stemming = True):\n",
    "        self.strip_headers = strip_headers\n",
    "        self.lower_case = lower_case\n",
    "        self.remove_punctuation = remove_punctuation\n",
    "        self.replace_urls = replace_urls\n",
    "        self.replace_numbers = replace_numbers\n",
    "        self.stemming = stemming\n",
    "        \n",
    "    def fit(self, X, y = None):\n",
    "        return self\n",
    "    def transform(self, X, y = None):\n",
    "        X_transformed = []\n",
    "        for email in X:\n",
    "            text = email_to_text(email) or ''\n",
    "            if self.lower_case:\n",
    "                text = text.lower()\n",
    "            if self.replace_urls:\n",
    "                extractor = URLExtract()\n",
    "                urls = list(set(extractor.find_urls(text)))\n",
    "                urls.sort(key = lambda url: len(url), reverse = True)\n",
    "                for url in urls:  # 替换url 为 ‘URL’\n",
    "                    text = text.replace(url, ' URL ')\n",
    "                    \n",
    "            if self.replace_numbers:  # 替换数字 正则\n",
    "                text = re.sub(r'\\d + (?:\\.\\d * (?:[eE]\\d + ))?', 'NUMBER', text)\n",
    "            if self.remove_punctuation:  # 删除标点符号\n",
    "                text = re.sub(r'\\W + ', ' ', text, flags = re.M)\n",
    "            word_counts = Counter(text.split())\n",
    "            if self.stemming and stemmer is not None:\n",
    "                stemmed_word_counts = Counter()\n",
    "                for word, count in word_counts.items():\n",
    "                    stemmed_word = stemmer.stem(word)\n",
    "                    stemmed_word_counts[stemmed_word] += count\n",
    "                word_counts = stemmed_word_counts\n",
    "                \n",
    "            X_transformed.append(word_counts)\n",
    "        return np.array(X_transformed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([Counter({'chuck': 1, 'murcko': 1, 'wrote:': 1, '>[...stuff...]': 1, 'yawn.': 1, 'r': 1}),\n",
       "       Counter({'the': 11, 'of': 9, 'and': 8, 'all': 3, 'to': 3, 'by': 3, 'have': 2, 'superstit': 2, 'one': 2, 'on': 2, 'been': 2, 'half': 2, 'teach': 2, 'some': 1, 'interest': 1, 'quotes...': 1, 'url': 1, 'thoma': 1, 'jefferson:': 1, '\"i': 1, 'examin': 1, 'known': 1, 'word,': 1, 'i': 1, 'do': 1, 'not': 1, 'find': 1, 'in': 1, 'our': 1, 'particular': 1, 'christian': 1, 'redeem': 1, 'feature.': 1, 'they': 1, 'are': 1, 'alik': 1, 'found': 1, 'fabl': 1, 'mythology.': 1, 'million': 1, 'innoc': 1, 'men,': 1, 'women': 1, 'children,': 1, 'sinc': 1, 'introduct': 1, 'christianity,': 1, 'burnt,': 1, 'tortured,': 1, 'fine': 1, 'imprisoned.': 1, 'what': 1, 'ha': 1, 'effect': 1, 'thi': 1, 'coercion?': 1, 'make': 1, 'world': 1, 'fool': 1, 'other': 1, 'hypocrites;': 1, 'support': 1, 'rogueri': 1, 'error': 1, 'over': 1, 'earth.\"': 1, 'six': 1, 'histor': 1, 'americans,': 1, 'john': 1, 'e.': 1, 'remsburg,': 1, 'letter': 1, 'william': 1, 'short': 1, 'jefferson': 1, 'again:': 1, '\"christianity...(ha': 1, 'become)': 1, 'most': 1, 'pervert': 1, 'system': 1, 'that': 1, 'ever': 1, 'shone': 1, 'man.': 1, '...rogueries,': 1, 'absurd': 1, 'untruth': 1, 'were': 1, 'perpetr': 1, 'upon': 1, 'jesu': 1, 'a': 1, 'larg': 1, 'band': 1, 'dupe': 1, 'import': 1, 'led': 1, 'paul,': 1, 'first': 1, 'great': 1, 'corrupt': 1, 'jesus.\"': 1}),\n",
       "       Counter({'>': 4, 'url': 4, 'in': 2, 'an': 2, 'and': 2, 'yahoo!': 2, 'group': 2, 'to': 2, '---': 1, 'forteana@y...,': 1, '\"martin': 1, 'adamson\"': 1, '<martin@s...>': 1, 'wrote:': 1, 'for': 1, 'alternative,': 1, 'rather': 1, 'more': 1, 'factual': 1, 'based,': 1, 'rundown': 1, 'on': 1, \"hamza'\": 1, 'career,': 1, 'includ': 1, 'hi': 1, 'belief': 1, 'that': 1, 'all': 1, 'non': 1, 'muslim': 1, 'yemen': 1, 'should': 1, 'be': 1, 'murder': 1, 'outright:': 1, 'we': 1, 'know': 1, 'how': 1, 'unbias': 1, 'memri': 1, 'is,': 1, \"don't\": 1, 'we....': 1, 'html': 1, 'rob': 1, '------------------------': 1, 'sponsor': 1, '---------------------~-->': 1, '4': 1, 'dvd': 1, 'free': 1, '+s&p': 1, 'join': 1, 'now': 1, '---------------------------------------------------------------------~->': 1, 'unsubscrib': 1, 'from': 1, 'thi': 1, 'group,': 1, 'send': 1, 'email': 1, 'to:': 1, 'forteana-unsubscribe@egroups.com': 1, 'your': 1, 'use': 1, 'of': 1, 'is': 1, 'subject': 1})],\n",
       "      dtype=object)"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 在一些邮件上 测试 转换器\n",
    "X_few = X_train[:3]\n",
    "X_few_wordcounts = EmailToWordCounterTransformer().fit_transform(X_few)\n",
    "X_few_wordcounts"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "   * 有了单词计数，我们需要把它们转换成向量。\n",
    "   * 为此，我们将构建另一个转换器，其fit（）方法将构建词汇表（最常用单词的有序列表），\n",
    "   * 其transform（）方法将使用词汇表将单词计数转换为向量--稀疏矩阵"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.sparse import csr_matrix\n",
    "\n",
    "class WordCounterToVectorTransformer(BaseEstimator, TransformerMixin):\n",
    "    def __init__(self, vocabulary_size = 1000):\n",
    "        self.vocabulary_size = vocabulary_size  # 词汇量\n",
    "        \n",
    "    def fit(self, X, y = None):\n",
    "        total_count = Counter()\n",
    "        for word_count in X:\n",
    "            for word, count in word_count.items():\n",
    "                total_count[word] += min(count, 10)  # 10设置数字上限\n",
    "        most_common = total_count.most_common()[:self.vocabulary_size]\n",
    "        self.most_common_ = most_common\n",
    "        self.vocabulary_ = {word: index + 1 for index, (word, count) in enumerate(most_common)}\n",
    "        return self\n",
    "    \n",
    "    def transform(self, X, y = None):\n",
    "        rows = []\n",
    "        cols = []\n",
    "        data = []\n",
    "        for row, word_count in enumerate(X):\n",
    "            for word, count in word_count.items():\n",
    "                rows.append(row) # 训练集 实例个数 邮件索引\n",
    "                cols.append(self.vocabulary_.get(word, 0)) # 取得单词在词汇表中的索引位置，0代表未出现在词汇表中\n",
    "                data.append(count)\n",
    "        return csr_matrix((data, (rows, cols)), shape = (len(X), self.vocabulary_size + 1)) \n",
    "        # 输出稀疏矩阵 +1因为第一列要显示未出现在词汇表中的单词统计数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "# from scipy.sparse import *\n",
    " \n",
    "# row =  [0,0,0,1,1,1,2,2,2]   #行指标\n",
    "# col =  [0,1,2,0,1,2,0,1,2]   #列指标\n",
    "# data = [1,0,1,0,1,1,1,1,0]   #在行指标列指标下的数字\n",
    "# team = csr_matrix((data,(row,col)), shape = (3,3))\n",
    "# print(team)\n",
    "# print(team.todense())\n",
    "# team.toarray()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(3, 11)"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "vocab_transformer = WordCounterToVectorTransformer(vocabulary_size = 10)\n",
    "X_few_vectors = vocab_transformer.fit_transform(X_few_wordcounts)\n",
    "X_few_vectors.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[  6,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0],\n",
       "       [101,  11,   9,   8,   1,   3,   3,   0,   1,   2,   3],\n",
       "       [ 64,   0,   1,   2,   4,   2,   1,   4,   2,   1,   0]],\n",
       "      dtype=int32)"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X_few_vectors.toarray() # 特征工程筛选数组"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "   * 第三行第一列中的64表示第三封电子邮件包含64个不属于词汇表的单词。\n",
    "   * 旁边的1表示词汇表中'of'单词在此电子邮件中出现一次。\n",
    "   * 旁边的2表示'and'单词出现两次,'the'没有出现"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'the': 1,\n",
       " 'of': 2,\n",
       " 'and': 3,\n",
       " 'url': 4,\n",
       " 'to': 5,\n",
       " 'all': 6,\n",
       " '>': 7,\n",
       " 'in': 8,\n",
       " 'on': 9,\n",
       " 'by': 10}"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "vocab_transformer.vocabulary_  # 以下数字是排名"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 训练第一个垃圾邮件分类器\n",
    "   * 转换整个数据集："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.pipeline import Pipeline  # 放入流水线中\n",
    "\n",
    "preprocess_pipeline = Pipeline([\n",
    "    ('email_to_wordcount', EmailToWordCounterTransformer()),\n",
    "    ('wordcount_to_vector', WordCounterToVectorTransformer()),\n",
    "])\n",
    "\n",
    "X_train_transformed = preprocess_pipeline.fit_transform(X_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[CV]  ................................................................\n",
      "[CV] .................................... , score=0.981, total=   0.0s\n",
      "[CV]  ................................................................\n",
      "[CV] .................................... , score=0.988, total=   0.1s\n",
      "[CV]  ................................................................\n",
      "[CV] .................................... , score=0.993, total=   0.1s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n",
      "[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    0.0s remaining:    0.0s\n",
      "[Parallel(n_jobs=1)]: Done   2 out of   2 | elapsed:    0.0s remaining:    0.0s\n",
      "[Parallel(n_jobs=1)]: Done   3 out of   3 | elapsed:    0.1s finished\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "0.9870833333333334"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.model_selection import cross_val_score\n",
    "\n",
    "log_clf = LogisticRegression(solver = 'liblinear', random_state = 42) # 采用逻辑回归分类器\n",
    "score = cross_val_score(log_clf, X_train_transformed, y_train, cv = 3, verbose = 3)\n",
    "score.mean()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "   * 得到分数超过98.7%，可以尝试多个模型，选择最好的模型，并使用交叉验证对它们进行微调。\n",
    "   * 在测试集上得到的精度/召回率："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "精度: 95.74%\n",
      "召回: 94.74%\n"
     ]
    }
   ],
   "source": [
    "from sklearn.metrics import precision_score, recall_score\n",
    "\n",
    "X_test_transformed = preprocess_pipeline.transform(X_test)\n",
    "\n",
    "log_clf = LogisticRegression(solver = 'liblinear', random_state = 42)\n",
    "log_clf.fit(X_train_transformed, y_train)\n",
    "\n",
    "y_pred = log_clf.predict(X_test_transformed)\n",
    "\n",
    "print('精度: {:.2f}%'.format(100 * precision_score(y_test, y_pred)))\n",
    "print('召回: {:.2f}%'.format(100 * recall_score(y_test, y_pred)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 总结\n",
    "   1. 加载数据并纵观数据大局\n",
    "   2. 获取邮件的组成结构\n",
    "   3. 对结构类型进行分析 发现垃圾邮件大多有HTML结构\n",
    "   4. 数据清洗，定义email对象中的HTML转换称纯文本方法\n",
    "   5. 对数据集拆分成训练集和测试集\n",
    "   6. 数据处理转换，对邮件的文本内容进行分词处理，通过nltk进行词干提取，对邮件出现的词汇进行计数统计，对所有邮件统计出了一个词汇表\n",
    "   7. 通过词汇表和邮件单词计数统计，将单词计数转化成向量矩阵\n",
    "   8. 把数据清洗和数据处理封装成两个转换器\n",
    "   9. 通过流水线来自动化处理数据\n",
    "   10. 使用逻辑回归线性分类器进行模型训练\n",
    "   11. 使用交叉验证进行微调\n",
    "   12. 在测试集上得到精度/召回率"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": true
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
